# author zybz 2022.3.16
# Step 1: split the dataset
import os
import random


class get_imggeSets():
    """Split the annotation files into name lists stored in the ImageSets folder.

    The split is executed immediately on construction: ``__init__`` calls
    ``run()``.
    """

    def __init__(self,
                 xml_file_path='../Annotations',
                 txt_save_path='../ImageSets',
                 trainval_percent=0.2,
                 train_percent=0.8):
        """Store the split configuration and perform the split.

        Args:
            xml_file_path (str, optional): Directory whose entries are listed
                and split. Defaults to '../Annotations'.
            txt_save_path (str, optional): Directory that receives
                trainval.txt, train.txt, val.txt and test.txt.
                Defaults to '../ImageSets'.
            trainval_percent (float, optional): Fraction of all entries that
                is sampled and written to trainval.txt; every entry that is
                NOT sampled is written to train.txt. Defaults to 0.2.
            train_percent (float, optional): Fraction of the sampled entries
                that is written to test.txt; the rest of the sampled entries
                go to val.txt. Defaults to 0.8.
        """
        self.trainval_percent = trainval_percent
        self.train_percent = train_percent
        self.xml_file_path = xml_file_path
        self.txt_save_path = txt_save_path
        self.run()

    def run(self):
        """Write trainval.txt, train.txt, val.txt and test.txt to ``txt_save_path``.

        Each line of an output file is one entry name from ``xml_file_path``
        with its file extension removed.

        Raises:
            OSError: If ``xml_file_path`` cannot be listed or an output file
                cannot be written.
            ValueError: If the requested sample is larger than the population
                (raised by ``random.sample``).
        """
        # Sorted so that a seeded `random` yields a reproducible split;
        # os.listdir returns entries in arbitrary order.
        names = [os.path.splitext(entry)[0]
                 for entry in sorted(os.listdir(self.xml_file_path))]
        num = len(names)
        tv = int(num * self.trainval_percent)
        tr = int(tv * self.train_percent)

        # Sample from sequences (random.sample does not accept sets on recent
        # Python versions), then convert to sets for O(1) membership tests.
        sampled = random.sample(range(num), tv)
        test_ids = set(random.sample(sampled, tr))
        sampled_ids = set(sampled)

        # Honor the configured output directory and make sure it exists.
        os.makedirs(self.txt_save_path, exist_ok=True)
        trainval_txt = os.path.join(self.txt_save_path, 'trainval.txt')
        test_txt = os.path.join(self.txt_save_path, 'test.txt')
        train_txt = os.path.join(self.txt_save_path, 'train.txt')
        val_txt = os.path.join(self.txt_save_path, 'val.txt')

        with open(trainval_txt, 'w') as ftrainval, \
                open(test_txt, 'w') as ftest, \
                open(train_txt, 'w') as ftrain, \
                open(val_txt, 'w') as fval:
            for idx, stem in enumerate(names):
                line = stem + '\n'
                if idx not in sampled_ids:
                    # Everything outside the sampled subset is training data.
                    ftrain.write(line)
                    continue
                ftrainval.write(line)
                if idx in test_ids:
                    ftest.write(line)
                else:
                    fval.write(line)

        print("Step 1 make imageSets successfully")


if __name__ == "__main__":
    # Constructing the object performs the split with the default paths and
    # ratios, because the constructor calls run() itself.
    imagesets = get_imggeSets()
